{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GYEYC1orNFdN"
   },
   "source": [
    "# Ungraded Lab: Using Convolutions with LSTMs\n",
    "\n",
    "Welcome to the final week of this course! In this lab, you will build upon the RNN models you built last week and append a convolution layer to it. As you saw in previous courses, convolution filters can also capture features from sequences so it's good to try them out when exploring model architectures. Let's begin!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vHuz2uNTNMnU"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BOjujz601HcS"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dsCvMZCCNON1"
   },
   "source": [
    "## Utilities\n",
    "\n",
    "You will be plotting the MAE and loss later so the `plot_series()` is extended to have more labelling functionality. The utilities for generating the synthetic data is the same as the previous labs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Zswl7jRtGzkk"
   },
   "outputs": [],
   "source": [
    "def plot_series(x, y, format=\"-\", start=0, end=None, \n",
    "                title=None, xlabel=None, ylabel=None, legend=None ):\n",
    "    \"\"\"\n",
    "    Visualizes time series data\n",
    "\n",
    "    Args:\n",
    "      x (array of int) - contains values for the x-axis\n",
    "      y (array of int or tuple of arrays) - contains the values for the y-axis\n",
    "      format (string) - line style when plotting the graph\n",
    "      start (int) - first time step to plot\n",
    "      end (int) - last time step to plot\n",
    "      title (string) - title of the plot\n",
    "      xlabel (string) - label for the x-axis\n",
    "      ylabel (string) - label for the y-axis\n",
    "      legend (list of strings) - legend for the plot\n",
    "    \"\"\"\n",
    "\n",
    "    # Setup dimensions of the graph figure\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    \n",
    "    # Check if there are more than two series to plot\n",
    "    if type(y) is tuple:\n",
    "\n",
    "      # Loop over the y elements\n",
    "      for y_curr in y:\n",
    "\n",
    "        # Plot the x and current y values\n",
    "        plt.plot(x[start:end], y_curr[start:end], format)\n",
    "\n",
    "    else:\n",
    "      # Plot the x and y values\n",
    "      plt.plot(x[start:end], y[start:end], format)\n",
    "\n",
    "    # Label the x-axis\n",
    "    plt.xlabel(xlabel)\n",
    "\n",
    "    # Label the y-axis\n",
    "    plt.ylabel(ylabel)\n",
    "\n",
    "    # Set the legend\n",
    "    if legend:\n",
    "      plt.legend(legend)\n",
    "\n",
    "    # Set the title\n",
    "    plt.title(title)\n",
    "\n",
    "    # Overlay a grid on the graph\n",
    "    plt.grid(True)\n",
    "\n",
    "    # Draw the graph on screen\n",
    "    plt.show()\n",
    "\n",
    "def trend(time, slope=0):\n",
    "    \"\"\"\n",
    "    Generates synthetic data that follows a straight line given a slope value.\n",
    "\n",
    "    Args:\n",
    "      time (array of int) - contains the time steps\n",
    "      slope (float) - determines the direction and steepness of the line\n",
    "\n",
    "    Returns:\n",
    "      series (array of float) - measurements that follow a straight line\n",
    "    \"\"\"\n",
    "\n",
    "    # Compute the linear series given the slope\n",
    "    series = slope * time\n",
    "\n",
    "    return series\n",
    "\n",
    "def seasonal_pattern(season_time):\n",
    "    \"\"\"\n",
    "    Just an arbitrary pattern, you can change it if you wish\n",
    "    \n",
    "    Args:\n",
    "      season_time (array of float) - contains the measurements per time step\n",
    "\n",
    "    Returns:\n",
    "      data_pattern (array of float) -  contains revised measurement values according \n",
    "                                  to the defined pattern\n",
    "    \"\"\"\n",
    "\n",
    "    # Generate the values using an arbitrary pattern\n",
    "    data_pattern = np.where(season_time < 0.4,\n",
    "                    np.cos(season_time * 2 * np.pi),\n",
    "                    1 / np.exp(3 * season_time))\n",
    "    \n",
    "    return data_pattern\n",
    "\n",
    "def seasonality(time, period, amplitude=1, phase=0):\n",
    "    \"\"\"\n",
    "    Repeats the same pattern at each period\n",
    "\n",
    "    Args:\n",
    "      time (array of int) - contains the time steps\n",
    "      period (int) - number of time steps before the pattern repeats\n",
    "      amplitude (int) - peak measured value in a period\n",
    "      phase (int) - number of time steps to shift the measured values\n",
    "\n",
    "    Returns:\n",
    "      data_pattern (array of float) - seasonal data scaled by the defined amplitude\n",
    "    \"\"\"\n",
    "    \n",
    "    # Define the measured values per period\n",
    "    season_time = ((time + phase) % period) / period\n",
    "\n",
    "    # Generates the seasonal data scaled by the defined amplitude\n",
    "    data_pattern = amplitude * seasonal_pattern(season_time)\n",
    "\n",
    "    return data_pattern\n",
    "\n",
    "def noise(time, noise_level=1, seed=None):\n",
    "    \"\"\"Generates a normally distributed noisy signal\n",
    "\n",
    "    Args:\n",
    "      time (array of int) - contains the time steps\n",
    "      noise_level (float) - scaling factor for the generated signal\n",
    "      seed (int) - number generator seed for repeatability\n",
    "\n",
    "    Returns:\n",
    "      noise (array of float) - the noisy signal\n",
    "    \"\"\"\n",
    "\n",
    "    # Initialize the random number generator\n",
    "    rnd = np.random.RandomState(seed)\n",
    "\n",
    "    # Generate a random number for each time step and scale by the noise level\n",
    "    noise = rnd.randn(len(time)) * noise_level\n",
    "    \n",
    "    return noise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5eI0WvrTNT6d"
   },
   "source": [
    "## Generate the Synthetic Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dtMZfbuxNtLA"
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "time = np.arange(4 * 365 + 1, dtype=\"float32\")\n",
    "baseline = 10\n",
    "amplitude = 40\n",
    "slope = 0.05\n",
    "noise_level = 5\n",
    "\n",
    "# Create the series\n",
    "series = baseline + trend(time, slope) + seasonality(time, period=365, amplitude=amplitude)\n",
    "\n",
    "# Update with noise\n",
    "series += noise(time, noise_level, seed=42)\n",
    "\n",
    "# Plot the results\n",
    "plot_series(time, series, xlabel='Time', ylabel='Value')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JXUYdSVmNYvN"
   },
   "source": [
    "## Split the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jvvA0or1OA5-"
   },
   "outputs": [],
   "source": [
    "# Define the split time\n",
    "split_time = 1000\n",
    "\n",
    "# Get the train set \n",
    "time_train = time[:split_time]\n",
    "x_train = series[:split_time]\n",
    "\n",
    "# Get the validation set\n",
    "time_valid = time[split_time:]\n",
    "x_valid = series[split_time:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UsuTYUWuNam3"
   },
   "source": [
    "## Prepare Features and Labels\n",
    "\n",
    "As mentioned in the lectures, you can experiment with different batch sizing here and see how it affects your results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "A-sg6XuP8C00"
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "window_size = 20\n",
    "batch_size = 16\n",
    "shuffle_buffer_size = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4sTTIOCbyShY"
   },
   "outputs": [],
   "source": [
    "def windowed_dataset(series, window_size, batch_size, shuffle_buffer):\n",
    "    \"\"\"Generates dataset windows\n",
    "\n",
    "    Args:\n",
    "      series (array of float) - contains the values of the time series\n",
    "      window_size (int) - the number of time steps to average\n",
    "      batch_size (int) - the batch size\n",
    "      shuffle_buffer(int) - buffer size to use for the shuffle method\n",
    "\n",
    "    Returns:\n",
    "      dataset (TF Dataset) - TF Dataset containing time windows\n",
    "    \"\"\"\n",
    "\n",
    "    # Add an axis for the feature dimension of RNN layers\n",
    "    series = tf.expand_dims(series, axis=-1)\n",
    "    \n",
    "    # Generate a TF Dataset from the series values\n",
    "    dataset = tf.data.Dataset.from_tensor_slices(series)\n",
    "    \n",
    "    # Window the data but only take those with the specified size\n",
    "    dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True)\n",
    "    \n",
    "    # Flatten the windows by putting its elements in a single batch\n",
    "    dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))\n",
    "\n",
    "    # Create tuples with features and labels \n",
    "    dataset = dataset.map(lambda window: (window[:-1], window[-1]))\n",
    "\n",
    "    # Shuffle the windows\n",
    "    dataset = dataset.shuffle(shuffle_buffer)\n",
    "    \n",
    "    # Create batches of windows\n",
    "    dataset = dataset.batch(batch_size)\n",
    "    \n",
    "    # Optimize the dataset for training\n",
    "    dataset = dataset.cache().prefetch(1)\n",
    "    \n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UrG17CqN8ql0"
   },
   "outputs": [],
   "source": [
    "# Generate the dataset windows\n",
    "train_set = windowed_dataset(x_train, window_size, batch_size, shuffle_buffer_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EQHk3KPuPbQU"
   },
   "source": [
    "## Build the Model\n",
    "\n",
    "Here is the model architecture you will be using. It is very similar to the last RNN you built but with the [Conv1D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv1D) layer at the input. One important [argument](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv1D#args) here is the `padding`. For time series data, it is good practice to not let computations for a particular time step to be affected by values into the future. Here is one way of looking at it:\n",
    "\n",
    "* Let's say you have a small time series window with these values: `[1, 2, 3, 4, 5]`. This means the value `1` is at `t=0`, `2` is at `t=1`, etc.\n",
    "* If you have a 1D kernel of size `3`, then the first convolution will be for the values at `[1, 2, 3]` which are values for `t=0` to `t=2`.\n",
    "* When you pass this to the first timestep of the `LSTM` after the convolution, it means that the value at `t=0` of the LSTM depends on `t=1` and `t=2` which are values into the future.\n",
    "* For time series data, you want computations to only rely on current and previous time steps.\n",
    "* One way to do that is to pad the array depending on the kernel size and stride. For a kernel size of 3 and stride of 1, the window can be padded as such: `[0, 0, 1, 2, 3, 4, 5]`. `1` is still at `t=0` and two zeroes are prepended to simulate values in the past.\n",
    "* This way, the first stride will be at `[0, 0, 1]` and this does not contain any future values when it is passed on to subsequent layers.\n",
    "\n",
    "The `Conv1D` layer does this kind of padding by setting `padding=causal` and you'll see that below."
   ]
  },
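  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the padding walkthrough above concrete, here is a minimal NumPy sketch. It is not the actual `Conv1D` implementation: the helper `causal_conv1d()`, the averaging kernel, and the `demo_*` names are assumptions made for illustration, and only the window `[1, 2, 3, 4, 5]` and the kernel size of `3` come from the text. If the padding is causal, changing the last value of the window should leave every earlier output unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def causal_conv1d(values, kernel):\n",
    "    # Hypothetical helper: stride-1 convolution with causal (left) zero padding\n",
    "    kernel_size = len(kernel)\n",
    "\n",
    "    # Prepend (kernel_size - 1) zeros so the output at step t only sees inputs up to t\n",
    "    padded = np.concatenate([np.zeros(kernel_size - 1), values])\n",
    "\n",
    "    # Slide the kernel one step at a time; the output keeps the input length\n",
    "    return np.array([np.dot(padded[t:t + kernel_size], kernel)\n",
    "                     for t in range(len(values))])\n",
    "\n",
    "# Window from the text and an assumed 3-tap averaging kernel\n",
    "demo_window = np.array([1.0, 2.0, 3.0, 4.0, 5.0])\n",
    "demo_kernel = np.ones(3) / 3\n",
    "\n",
    "# The first output only averages [0, 0, 1]\n",
    "causal_out = causal_conv1d(demo_window, demo_kernel)\n",
    "print(causal_out)\n",
    "\n",
    "# Changing the last value must leave all earlier outputs untouched\n",
    "changed_window = demo_window.copy()\n",
    "changed_window[-1] = 100.0\n",
    "changed_out = causal_conv1d(changed_window, demo_kernel)\n",
    "print(np.allclose(causal_out[:-1], changed_out[:-1]))"
   ]
  },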
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Yqc2GTsps0qf"
   },
   "outputs": [],
   "source": [
    "# Reset states generated by Keras\n",
    "tf.keras.backend.clear_session()\n",
    "\n",
    "# Build the model\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.Input(shape=(window_size,1)),\n",
    "    tf.keras.layers.Conv1D(filters=64, kernel_size=3,\n",
    "                      strides=1, padding=\"causal\",\n",
    "                      activation=\"relu\"),\n",
    "    tf.keras.layers.LSTM(64, return_sequences=True),\n",
    "    tf.keras.layers.LSTM(64),\n",
    "    tf.keras.layers.Dense(1),\n",
    "    tf.keras.layers.Lambda(lambda x: x * 400)\n",
    "])\n",
    "\n",
    "# Print the model summary\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cQdIfpeBPmTm"
   },
   "source": [
    "## Tune the Learning Rate\n",
    "\n",
    "In the previous labs, you are using different models for tuning and training. That is a valid approach but you can also use the same model for both. Before tuning, you can use the [`get_weights()`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#get_weights) method so you can reset it later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-Fo71QNs7UJY"
   },
   "outputs": [],
   "source": [
    "# Get initial weights\n",
    "init_weights = model.get_weights()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bcVbJDQNAe39"
   },
   "source": [
    "After that, you can tune the model as usual."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Dcbk9QMEKsLi"
   },
   "outputs": [],
   "source": [
    "# Set the learning rate scheduler\n",
    "lr_schedule = tf.keras.callbacks.LearningRateScheduler(\n",
    "    lambda epoch: 1e-8 * 10**(epoch / 20))\n",
    "\n",
    "# Initialize the optimizer\n",
    "optimizer = tf.keras.optimizers.SGD(momentum=0.9)\n",
    "\n",
    "# Set the training parameters\n",
    "model.compile(loss=tf.keras.losses.Huber(), optimizer=optimizer)\n",
    "\n",
    "# Train the model\n",
    "history = model.fit(train_set, epochs=100, callbacks=[lr_schedule])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MjTvASUns0qh"
   },
   "outputs": [],
   "source": [
    "# Define the learning rate array\n",
    "lrs = 1e-8 * (10 ** (np.arange(100) / 20))\n",
    "\n",
    "# Set the figure size\n",
    "plt.figure(figsize=(10, 6))\n",
    "\n",
    "# Set the grid\n",
    "plt.grid(True)\n",
    "\n",
    "# Plot the loss in log scale\n",
    "plt.semilogx(lrs, history.history[\"loss\"])\n",
    "\n",
    "# Increase the tickmarks size\n",
    "plt.tick_params('both', length=10, width=1, which='both')\n",
    "\n",
    "# Set the plot boundaries\n",
    "plt.axis([1e-8, 1e-3, 0, 50])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7-mAgyNmdJCa"
   },
   "source": [
    "## Train the Model\n",
    "\n",
    "To reset the weights, you can simply call the [`set_weights()`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#set_weights) and pass in the saved weights from earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MRyv3N36XdNo"
   },
   "outputs": [],
   "source": [
    "# Reset states generated by Keras\n",
    "tf.keras.backend.clear_session()\n",
    "\n",
    "# Reset the weights\n",
    "model.set_weights(init_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MOpeP196AvAU"
   },
   "source": [
    "Then you can set the training parameters and start training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2TuiT7GWRuA5"
   },
   "outputs": [],
   "source": [
    "# Set the learning rate\n",
    "learning_rate = 1e-7\n",
    "\n",
    "# Set the optimizer\n",
    "optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)\n",
    "\n",
    "# Set the training parameters\n",
    "model.compile(loss=tf.keras.losses.Huber(),\n",
    "              optimizer=optimizer,\n",
    "              metrics=[\"mae\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wHt0qDRyPe-N"
   },
   "outputs": [],
   "source": [
    "# Train the model\n",
    "history = model.fit(train_set,epochs=500)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mEGAzXrmBCVd"
   },
   "source": [
    "Training can be a bit unstable especially as the weights start to converge so you may want to visualize it to see if it is still trending down. The earlier epochs might dominate the graph so it's also good to zoom in on the later parts of training to properly observe the parameters. The code below visualizes the `mae` and `loss` for all epochs, and also zooms in at the last 80%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ok8LjNbbkig4"
   },
   "outputs": [],
   "source": [
    "# Get mae and loss from history log\n",
    "mae=history.history['mae']\n",
    "loss=history.history['loss']\n",
    "\n",
    "# Get number of epochs\n",
    "epochs=range(len(loss)) \n",
    "\n",
    "# Plot mae and loss\n",
    "plot_series(\n",
    "    x=epochs, \n",
    "    y=(mae, loss), \n",
    "    title='MAE and Loss', \n",
    "    xlabel='Epochs',\n",
    "    legend=['MAE', 'Loss']\n",
    "    )\n",
    "\n",
    "# Only plot the last 80% of the epochs\n",
    "zoom_split = int(epochs[-1] * 0.2)\n",
    "epochs_zoom = epochs[zoom_split:]\n",
    "mae_zoom = mae[zoom_split:]\n",
    "loss_zoom = loss[zoom_split:]\n",
    "\n",
    "# Plot zoomed mae and loss\n",
    "plot_series(\n",
    "    x=epochs_zoom, \n",
    "    y=(mae_zoom, loss_zoom), \n",
    "    title='MAE and Loss', \n",
    "    xlabel='Epochs',\n",
    "    legend=['MAE', 'Loss']\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "l4BToQgdS9Lo"
   },
   "source": [
    "## Model Prediction\n",
    "\n",
    "Once training is done, you can generate the model predictions and plot them against the validation set.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HUWdHyyiS3o_"
   },
   "outputs": [],
   "source": [
    "def model_forecast(model, series, window_size, batch_size):\n",
    "    \"\"\"Uses an input model to generate predictions on data windows\n",
    "\n",
    "    Args:\n",
    "      model (TF Keras Model) - model that accepts data windows\n",
    "      series (array of float) - contains the values of the time series\n",
    "      window_size (int) - the number of time steps to include in the window\n",
    "      batch_size (int) - the batch size\n",
    "\n",
    "    Returns:\n",
    "      forecast (numpy array) - array containing predictions\n",
    "    \"\"\"\n",
    "\n",
    "    # Add an axis for the feature dimension of RNN layers\n",
    "    series = tf.expand_dims(series, axis=-1)\n",
    "    \n",
    "    # Generate a TF Dataset from the series values\n",
    "    dataset = tf.data.Dataset.from_tensor_slices(series)\n",
    "\n",
    "    # Window the data but only take those with the specified size\n",
    "    dataset = dataset.window(window_size, shift=1, drop_remainder=True)\n",
    "\n",
    "    # Flatten the windows by putting its elements in a single batch\n",
    "    dataset = dataset.flat_map(lambda w: w.batch(window_size))\n",
    "    \n",
    "    # Create batches of windows\n",
    "    dataset = dataset.batch(batch_size).prefetch(1)\n",
    "    \n",
    "    # Get predictions on the entire dataset\n",
    "    forecast = model.predict(dataset, verbose=0)\n",
    "    \n",
    "    return forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Hsmw8TOUS7K_"
   },
   "outputs": [],
   "source": [
    "# Reduce the original series\n",
    "forecast_series = series[split_time-window_size:-1]\n",
    "\n",
    "# Use helper function to generate predictions\n",
    "forecast = model_forecast(model, forecast_series, window_size, batch_size)\n",
    "\n",
    "# Drop single dimensional axes\n",
    "results = forecast.squeeze()\n",
    "\n",
    "# Plot the results\n",
    "plot_series(time_valid, (x_valid, results))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mR1WxVICDSIE"
   },
   "source": [
    "You can then compute the metrics as usual."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ULKO3JINdqkp"
   },
   "outputs": [],
   "source": [
    "## Compute the MAE and MSE\n",
    "print(tf.keras.metrics.mse(x_valid, results).numpy())\n",
    "print(tf.keras.metrics.mae(x_valid, results).numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QAS7XGJe58hY"
   },
   "source": [
    "## Wrap Up\n",
    "\n",
    "In this lab, you were able to build and train a CNN-RNN model for forecasting. This concludes the series of notebooks on training with synthetic data. In the next labs, you will be looking at a real world time series dataset, particularly sunspot cycles. See you there!\n",
    "\n",
    "If you won't explore the optional exercises below, please uncomment the cell below and run it to free up resources for the next labs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Uncomment the code below if you will not explore the optional section.\n",
    "# # Shutdown the kernel to free up resources. \n",
    "# # Note: You can expect a pop-up when you run this cell. You can safely ignore that and just press `Ok`.\n",
    "\n",
    "# from IPython import get_ipython\n",
    "\n",
    "# k = get_ipython().kernel\n",
    "\n",
    "# k.do_shutdown(restart=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "74P0oB1MrO5K"
   },
   "source": [
    "## Optional - Adding a Callback for Early Stopping\n",
    "\n",
    "In this optional section, you will add a callback to stop training when a metric is met. You already did this in the first course of this specialization and now would be a good time to review."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7E_XS6Kw7lCS"
   },
   "source": [
    "First, you need to prepare a validation set that the model can use and monitor. As shown in the previous lab, you can use the `windowed_dataset()` function to prepare this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aepWmBdueTOB"
   },
   "outputs": [],
   "source": [
    "# Generate data windows from the validation set\n",
    "val_set = windowed_dataset(x_valid, window_size, batch_size, shuffle_buffer_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dEbWDGzU72-h"
   },
   "source": [
    "You can reset the weights of the model or just continue from where you left off."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TkeXyLYRoJxA"
   },
   "outputs": [],
   "source": [
    "# Uncomment if you want to reset the weights\n",
    "# model.set_weights(init_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HOTxGXlw8Qh-"
   },
   "source": [
    "Next, you will define a callback function that is run every end of an epoch. Inside, you will define the condition to stop training. For this lab, you will set it to stop when the `val_mae` is less than `5.7`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fpajFWODpDIy"
   },
   "outputs": [],
   "source": [
    "class myCallback(tf.keras.callbacks.Callback):\n",
    "  def on_epoch_end(self, epoch, logs={}):\n",
    "    '''\n",
    "    Halts the training when a certain metric is met\n",
    "\n",
    "    Args:\n",
    "      epoch (integer) - index of epoch (required but unused in the function definition below)\n",
    "      logs (dict) - metric results from the training epoch\n",
    "    '''\n",
    "\n",
    "    # Check the validation set MAE\n",
    "    if(logs.get('val_mae') < 5.7):\n",
    "\n",
    "      # Stop if threshold is met\n",
    "      print(\"\\nRequired val MAE is met so cancelling training!\")\n",
    "      self.model.stop_training = True\n",
    "\n",
    "# Instantiate the class\n",
    "callbacks = myCallback()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "abJp8xCR9CUt"
   },
   "source": [
    "Remember to set an appropriate learning rate here. If you're starting from random weights, you may want to use the same rate you used earlier. If you did not reset the weights however, you can use a lower learning rate so the model can learn better. If all goes well, the training will stop before the set 500 epochs are completed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xY9o0I0ieYV-"
   },
   "outputs": [],
   "source": [
    "# Set the learning rate\n",
    "learning_rate = 4e-8\n",
    "\n",
    "# Set the optimizer\n",
    "optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)\n",
    "\n",
    "# Set the training parameters\n",
    "model.compile(loss=tf.keras.losses.Huber(),\n",
    "              optimizer=optimizer,\n",
    "              metrics=[\"mae\"])\n",
    "\n",
    "# Train the model\n",
    "history = model.fit(train_set, epochs=500, validation_data=val_set, callbacks=[callbacks])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hXnLf9409n1E"
   },
   "source": [
    "In practice, you normally have a separate test set to evaluate against unseen data. For this exercise however, the dataset is already very small so let's just use the same validation set just to verify that the results are comparable to the one you got earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gUv4QkyaoDxk"
   },
   "outputs": [],
   "source": [
    "# Reduce the original series\n",
    "forecast_series = series[split_time-window_size:-1]\n",
    "\n",
    "# Use helper function to generate predictions\n",
    "forecast = model_forecast(model, forecast_series, window_size, batch_size)\n",
    "\n",
    "# Drop single dimensional axis\n",
    "results = forecast.squeeze()\n",
    "\n",
    "# Plot the results\n",
    "plot_series(time_valid, (x_valid, results))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IjlbjWJG-b-1"
   },
   "source": [
    "The computed metrics here will be slightly different from the one shown in the training output because it has more points to evaluate. Remember that `x_valid` has 461 points that corresponds to `t=1000` to `t=1460`. `val_set` (which is a windowed dataset from `x_valid`), on the other hand, only has 441 points because it cannot generate predictions for `t=1000` to `t=1019` (i.e. windowing will start there)."
   ]
  },
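  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The point counts above follow from simple window arithmetic, sketched below. The values are copied from this lab's settings, and the variable names are hypothetical and local to this sketch so they do not overwrite anything defined earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values copied from this lab's settings (names are local to this sketch)\n",
    "total_points = 4 * 365 + 1\n",
    "demo_split_time = 1000\n",
    "demo_window_size = 20\n",
    "\n",
    "# Points in x_valid: t=1000 to t=1460\n",
    "n_valid_points = total_points - demo_split_time\n",
    "\n",
    "# Each window needs demo_window_size features plus 1 label\n",
    "n_val_windows = n_valid_points - (demo_window_size + 1) + 1\n",
    "\n",
    "# Time step of the first label that the windowed validation set can provide\n",
    "first_label_step = demo_split_time + demo_window_size\n",
    "\n",
    "print(n_valid_points, n_val_windows, first_label_step)"
   ]
  },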
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z3tKgbO2oEgW"
   },
   "outputs": [],
   "source": [
    "## Compute the MAE and MSE\n",
    "print(tf.keras.metrics.mse(x_valid, results).numpy())\n",
    "print(tf.keras.metrics.mae(x_valid, results).numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the cell below to free up resources for the next lab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: You can expect a pop-up when you run this cell. You can safely ignore that and just press `Ok`.\n",
    "\n",
    "from IPython import get_ipython\n",
    "\n",
    "k = get_ipython().kernel\n",
    "\n",
    "k.do_shutdown(restart=False)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "C4_W4_Lab_1_LSTM.ipynb",
   "private_outputs": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0rc1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
